#!/usr/bin/python3

import argparse
import numpy as np
import time
import matplotlib.pyplot as plt


# Command-line interface: times are given in Myr, temperatures in K.
parser = argparse.ArgumentParser(
    description='Thermal evolution model of the EL planetesimal with breakup')
parser.add_argument("-b", "--breakuptime", default=180, type=int,
                    help='Time of the EL parent breakup')
# NOTE: the long option name is misspelled ("Simualtion"); it is kept as is
# because the value is read back as args.totalSimualtionTime.
parser.add_argument("-t", "--totalSimualtionTime", default=200, type=float,
                    help='Total simulation time')
parser.add_argument("-p", "--plot", action="store_true", help="Show plots")
parser.add_argument("-s", "--save", action="store_true",
                    help="Save temperature curves")
parser.add_argument("-o", "--outsidetemperature", default=200, type=int,
                    help='Temperature at surface')

args = parser.parse_args()


def loadKappa(path='kappaVStemp.dat'):
    """Load the conductivity table used by getKappa().

    The table is stored in the module-level array ``_kappaData``.
    getKappa() interpolates column 1 (conductivity) against column 0
    (temperature).

    Parameters
    ----------
    path : str, optional
        Whitespace-separated text table of (temperature, conductivity) rows.
        Defaults to 'kappaVStemp.dat' in the current working directory.
    """
    global _kappaData
    # ndmin=2 keeps a one-row table 2-D, so the column slicing in
    # getKappa() still works.
    _kappaData = np.loadtxt(path, ndmin=2)

def getKappa(T):
    """Return the thermal conductivity at temperature(s) ``T``.

    The value is linearly interpolated in the table loaded by loadKappa().
    Following numpy.interp, temperatures outside the tabulated range take
    the end-point values.
    """
    table = _kappaData
    temperatures, conductivities = table[:, 0], table[:, 1]
    return np.interp(T, temperatures, conductivities)

def loadCoreTemp(path='trieloffFig5_coreTemp.dat'):
    """Load the core-temperature history used by getCoreTemp().

    The table is stored in the module-level array ``_coreTemp``.
    getCoreTemp() interpolates column 1 (temperature) against column 0,
    which is compared with log10 of the requested time.

    Parameters
    ----------
    path : str, optional
        Whitespace-separated text table. Defaults to
        'trieloffFig5_coreTemp.dat' in the current working directory.
    """
    global _coreTemp
    # ndmin=2 keeps a one-row table 2-D, so the column slicing in
    # getCoreTemp() still works.
    _coreTemp = np.loadtxt(path, ndmin=2)

def getCoreTemp(time):
    """Return the core temperature at ``time``.

    The table loaded by loadCoreTemp() is indexed by log10(time), so the
    interpolation is linear in log-time. Following numpy.interp, times
    outside the table take the end-point values.
    """
    logTime = np.log10(time)
    logTimes, temperatures = _coreTemp[:, 0], _coreTemp[:, 1]
    return np.interp(logTime, logTimes, temperatures)

# Select which heat-equation discretisation defines advanceTemp() below.
# If both flags are True the later (Cartesian) definition wins; if neither
# is True, advanceTemp is never defined.
kappaTempDependentSpherical = True
kappaTempDependent = False

# Run parameters from the command line (times in Myr, temperature in K).
surfTemp = args.outsidetemperature
breakupTime = args.breakuptime
totalSimulationTime = args.totalSimualtionTime


## Default scheme: temperature-dependent conductivity and a heat equation
## with spherical symmetry.
if kappaTempDependentSpherical:
    def advanceTemp(T, R, kappa, rho, C, dt):
        """Advance the radial temperature profile by one explicit Euler step.

        Discretises the spherically symmetric heat equation with a
        space-dependent conductivity,

            rho*C*dT/dt = kappa*T'' + (2*kappa/r + dkappa/dr)*T',

        using forward differences for the first derivatives and a central
        difference for T''. Only the interior nodes 1..len(R)-2 are updated.
        The first and last nodes are returned unchanged, so the caller
        imposes the boundary conditions.

        Parameters
        ----------
        T : array_like
            Temperatures on the radial grid R.
        R : array_like
            Radial node positions (same length as T). The interior nodes
            must be non-zero.
        kappa : array_like or float
            Thermal conductivity at each node; a scalar is broadcast.
        rho, C : float
            Density and specific heat capacity.
        dt : float
            Time step.

        Returns
        -------
        numpy.ndarray
            A new array of updated temperatures; ``T`` itself is not modified.
        """
        Told = np.asarray(T, dtype=float)
        R = np.asarray(R, dtype=float)
        kappa = np.broadcast_to(np.asarray(kappa, dtype=float), Told.shape)
        # Work on a copy: in an explicit step every term must use the
        # temperatures from the start of the step, never already-updated
        # neighbours.
        Tnew = Told.copy()
        Tc, Tp, Tm = Told[1:-1], Told[2:], Told[:-2]
        kc, kp = kappa[1:-1], kappa[2:]
        dr = R[2:] - R[1:-1]
        Tnew[1:-1] = Tc + dt/rho/C*(
            2*kc/R[1:-1]*(Tp-Tc)/dr +
            (kp-kc)/dr*(Tp-Tc)/dr +
            kc*(Tp - 2*Tc + Tm)/dr/dr
            )
        return Tnew
        
# Disabled alternative kept for reference. It is only a string literal, so it
# never runs: an explicit update with a temperature-independent conductivity
# used as a single value, with no 2*kappa/r term.
# NOTE(review): `Tnew = T` makes Tnew an alias of T, so inside the loop the
# already-updated T[i+1] feeds the update of T[i]. Copy T before re-enabling.
'''
## kappa is not temperature dependent
def advanceTemp(T, R, kappa, rho, C, dt):
    Tnew = T
    rNpt=len(R)
    for i in reversed(range(1, rNpt-1)):
        dr=R[i+1]-R[i]
        Tnew[i] = T[i] + kappa/rho/C*dt * (T[i+1] - 2*T[i] + T[i-1])/dr/dr
    return Tnew
'''

## Alternative scheme: temperature-dependent conductivity in Cartesian
## (planar) geometry, i.e. without the 2*kappa/r term.
if kappaTempDependent:
    def advanceTemp(T, R, kappa, rho, C, dt):
        """Advance the temperature profile by one explicit Euler step.

        Discretises

            rho*C*dT/dt = kappa*T'' + dkappa/dx*T',

        using forward differences for the first derivatives and a central
        difference for T''. Only the interior nodes 1..len(R)-2 are updated.
        The first and last nodes are returned unchanged, so the caller
        imposes the boundary conditions.

        Parameters
        ----------
        T : array_like
            Temperatures on the grid R.
        R : array_like
            Node positions (same length as T).
        kappa : array_like or float
            Thermal conductivity at each node; a scalar is broadcast.
        rho, C : float
            Density and specific heat capacity.
        dt : float
            Time step.

        Returns
        -------
        numpy.ndarray
            A new array of updated temperatures; ``T`` itself is not modified.
        """
        Told = np.asarray(T, dtype=float)
        R = np.asarray(R, dtype=float)
        kappa = np.broadcast_to(np.asarray(kappa, dtype=float), Told.shape)
        # Work on a copy so every term uses the start-of-step temperatures
        # (explicit scheme), not already-updated neighbours.
        Tnew = Told.copy()
        Tc, Tp, Tm = Told[1:-1], Told[2:], Told[:-2]
        kc, kp = kappa[1:-1], kappa[2:]
        dr = R[2:] - R[1:-1]
        Tnew[1:-1] = Tc + dt/rho/C * \
          ( (kp-kc)/dr * (Tp-Tc)/dr +
            kc * (Tp - 2*Tc + Tm)/dr/dr )
        return Tnew

# Load the tabulated conductivity and core-temperature data that getKappa()
# and getCoreTemp() interpolate.
loadKappa()
loadCoreTemp()

# Disabled debug plot: conductivity as a function of temperature.
'''
T=np.linspace(0,1400)
kappa=getKappa(T)
plt.plot(T, kappa)
plt.show()
exit(0)
'''
# Disabled debug plot: interpolated core temperature vs time (log axis)
# against the tabulated points.
'''
logtime=np.linspace(0,3)
coreTemp=getCoreTemp(10**logtime)
plt.plot(10**logtime, coreTemp)
plt.scatter(10**_coreTemp[:,0], _coreTemp[:,1])
plt.xscale('log')
plt.show()
exit(0)
'''

# ---- Radial grid and material constants of the EL parent body ----------
rMax = 130.e3                       # outer radius (m)
rNpt = 260                          # number of radial nodes, centre to surface
R = np.linspace(0, rMax, rNpt)
dr = R[1] - R[0]

# Provisional uniform 1000 K profile with the surface node at surfTemp; it is
# replaced by the tabulated initial profile further down. Tnew starts out as
# an alias of the same array.
T = np.ones(rNpt) * 1000
Tnew = T
T[rNpt - 1] = surfTemp

# Density and specific heat, presumably SI (kg/m^3, J/(kg K)).
# TODO confirm: the origin of the 0.5 factor on the density is not stated here.
rho = 3500 * 0.5
#kappa = 1.5
C = 650

# Simulation clock in seconds; the first steps are 0.1 yr long.
t = 0
dt = 3600 * 24 * 365 * 0.1

# Grid node closest to the EL depth, 17 km below the surface.
elDepth = 17.e3
k = np.argmin((rMax - R - elDepth)**2)
print('EL depth at radius', R[k]/1e3, "km")
# Disabled: profile plot against depth.
'''
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)
line1, = ax.plot(np.flip(R), T, 'r-') # Returns a tuple of line objects, thus the comma
'''

#### Initial temperature profile from Fig. 5 of Trieloff et al. (2022, Icarus)
# Tabulated temperatures (K) versus depth at 5 Myr after CAIs, shifted by
# +90 K. The 12th value belongs to the body centre (depth = rMax).
iniTemp = np.array([266.881836, 414.051640, 649.795894, 874.638665,
                    1039.523353, 1126.735101, 1149.150613, 1181.855033,
                    1193.865982, 1201.783737, 1206.801602, 1204.801602]) + 90
# Same profile at 10 Myr (with a +150 K shift), kept for reference:
#iniTemp=np.array([238.177874, 322.776573, 469.414317, 647.071584, 847.288503, 985.466377, 1013.665944, 1078.524946, 1126.464208, 1168.763557, 1196.963124, 1264.642082])+150
depth = np.append(np.array([1, 4, 7, 11, 15, 17, 19, 23, 27, 30, 33]) * 1.e3, rMax)
print(depth)
# Radii of the tabulated points (used only by the disabled plot below).
rr = rMax - depth

# Interpolate the table linearly onto the depth rMax - R of every node.
# Nodes shallower than the first tabulated depth take its temperature.
#jj=np.argwhere(R<=33.e3)
ynew = np.interp(rMax - R, depth, iniTemp)
# Disabled alternative: cubic-spline interpolation of the same table.
'''
from scipy.interpolate import CubicSpline
spl = CubicSpline(depth, iniTemp)
ynew = spl(rMax-R)
'''
T = ynew
T[-1] = surfTemp  # the surface node is pinned to the outside temperature

# Disabled: check of the interpolated profile against the table.
'''
#plt.plot(R,ynew)
plt.plot(R, T)
plt.scatter(rr, iniTemp)
#plt.xscale('log')
plt.show()
'''


if args.plot:
    # Interactive figure comparing the modelled temperature at the EL-depth
    # node k with the Trieloff et al. (2022) curve for 17 km depth.
    plt.ion()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    logTime17, temp17 = np.loadtxt('trieloffFig5_depth17km.dat', unpack=True)
    ax.plot(10**logTime17, temp17, label='Trieloff 22')
    ax.legend()
    ax.set_xscale('log')
    ax.set_xlim((5, 200))
    ax.set_xlabel("Time after CAIs (Myr)")
    ax.set_ylabel("Temperature (K)")
    # ax.plot returns a list with one Line2D; unpack it. line1 is later fed
    # the recorded history and line2 the latest point.
    [line1] = ax.plot(t/3600/24/365/1.e6 + 5, T[k])
    [line2] = ax.plot(t/3600/24/365/1.e6 + 5, T[k], 'o')

# History of (time in Myr after CAIs, temperature at node k) for the live
# plot and the optional output file.
tV=[]
TV=[]
j=0  # step counter used to throttle printing/plotting
###### Thermal evolution of the EL planetesimal
# Runs until the breakup time (see the break below) or totalSimulationTime Myr.
while t<3600*24*365*totalSimulationTime*1e6:
    #kappa=1. ## constant kappa
    #kappa = np.ones(rNpt)*1.3
    # FOR Cartesian
    #kappa = getKappa(T)*0.15
    #kappa[rNpt-2:rNpt-1]=0.1
    # for spherical symmetry: tabulated kappa(T) scaled by 0.15; the single
    # node just below the surface (index rNpt-2) gets a fixed low value.
    kappa = getKappa(T)*0.15
    kappa[rNpt-2:rNpt-1]=0.017
    # Inner boundary: the 10 innermost nodes follow the tabulated core
    # temperature at the current time in Myr after CAIs (clock + 5 Myr).
    T[0:10] = getCoreTemp(t/(3600*24*365*1e6) + 5.)
    Tnew = advanceTemp(T, R, kappa, rho, C, dt)
    T=Tnew
    t+=dt
    j+=1

    # Every 101 steps: refresh the live plot, print progress and record the
    # temperature at the EL-depth node k.
    if j>100:
        if args.plot:
            line1.set_ydata(TV)
            line1.set_xdata(tV)
            # set_xdata/set_ydata need sequences: a bare scalar is deprecated
            # since Matplotlib 3.7 and rejected by later releases.
            line2.set_ydata([T[k]])
            line2.set_xdata([t/3600/24/365/1.e6 + 5])
            fig.canvas.draw()
            fig.canvas.flush_events()

        # Columns: time since the run start (Myr), T at the surface node, at
        # the node below it, 29 nodes below the surface, and at node k.
        print(t/3600/24/365/1e6, Tnew[rNpt-1], Tnew[rNpt-2], Tnew[rNpt-30], Tnew[k])
        j=0
        tV.append(t/3600/24/365/1.e6 + 5)
        TV.append(Tnew[k])

    # Lengthen the time step once the initial transient has passed:
    # 0.1 yr -> 1 kyr after 100 yr -> 5 kyr after 1 kyr.
    if t>(100.*365*24*3600):
        dt=3600*24*365*1.0e3

    if t>(1000.*365*24*3600):
        dt=3600*24*365*5.0e3

    # Stop at the breakup: breakupTime is in Myr after CAIs, while the
    # clock starts 5 Myr after CAIs.
    if t>((breakupTime-5)*1e6*365*24*3600):
        break

# Extraction of the Athor family progenitor. athorFamProgRadius is measured
# as a depth below the parent's surface. Nodes kAFP .. rNpt-2 (the old surface
# node is dropped) are re-used as a new radial grid starting at r = 0.
athorFamProgRadius = 32.e3  # in m
kAFP = np.argmin((rMax - R - athorFamProgRadius)**2)
print('Athor family progenitor is at radius', R[kAFP]/1e3, "km", 'i.e. index of', kAFP)
print('Reminder: EL depth at radius', R[k]/1e3, "km", 'i.e. index of', k)

rAFP = R[kAFP:rNpt-1] - R[kAFP]
print(rAFP)
TAFP = T[kAFP:rNpt-1]
print(TAFP)

# Switch the model state to the progenitor grid. k is shifted so it still
# points at the same (EL-depth) node.
R, T = rAFP, TAFP
rNpt = len(R)
k -= kAFP
###### Thermal evolution of the Athor family progenitor
# Continues from the breakup up to totalSimulationTime on the progenitor grid
# (R, T, rNpt and k are reassigned just before this loop). dt keeps the value
# it had at the end of the parent-body loop.
while t<3600*24*365*totalSimulationTime*1e6:
    #kappa=1. ## constant kappa
    #kappa = np.ones(rNpt)*1.3
    # FOR Cartesian
    #kappa = getKappa(T)*0.15
    #kappa[rNpt-2:rNpt-1]=0.1
    # for spherical symmetry: tabulated kappa(T) scaled by 0.15; the single
    # node just below the surface (index rNpt-2) gets a fixed low value.
    kappa = getKappa(T)*0.15
    kappa[rNpt-2:rNpt-1]=0.017
    # Boundary conditions: zero gradient at the new centre and a fixed
    # surface temperature.
    T[0] = T[1]
    T[rNpt-1]=surfTemp
    Tnew = advanceTemp(T, R, kappa, rho, C, dt)
    T=Tnew
    t+=dt
    j+=1

    # Every 101 steps: refresh the live plot, print progress and record the
    # temperature at the EL-depth node k.
    if j>100:
        if args.plot:
            line1.set_ydata(TV)
            line1.set_xdata(tV)
            # set_xdata/set_ydata need sequences: a bare scalar is deprecated
            # since Matplotlib 3.7 and rejected by later releases.
            line2.set_ydata([T[k]])
            line2.set_xdata([t/3600/24/365/1.e6 + 5])
            fig.canvas.draw()
            fig.canvas.flush_events()

        print(t/3600/24/365/1e6, Tnew[rNpt-1], Tnew[rNpt-2], Tnew[rNpt-30], Tnew[k])
        j=0
        tV.append(t/3600/24/365/1.e6 + 5)
        TV.append(Tnew[k])

if args.save:
    # One file per breakup time, e.g. eltemp-180.dat. Columns: time (Myr
    # after CAIs) and temperature at the EL-depth node.
    fname = 'eltemp-{:03d}.dat'.format(int(breakupTime))
    np.savetxt(fname, np.transpose([tV, TV]))

# Stop here; the Monte-Carlo comparison code below is currently unreachable.
# SystemExit is used instead of the `exit` helper, which the `site` module
# provides for interactive sessions only.
raise SystemExit(0)

# Thermochronometer data: name, method, age +- uncertainty and closure
# temperature +- uncertainty (columns 0,1,2,4,5,7 of tagesV2.dat).
dtype1 = np.dtype([('mateorite', '|S20'), ('thermochronometer', '|S7'),
                   ('age', 'f8'), ('ageUncertainty', 'f8'),
                   ('temperature', 'f8'), ('temperatureUncertainty', 'f8')])
metName, tcName, mAge, sAge, mTemp, msTemp = np.loadtxt('tagesV2.dat', usecols=(0,1,2,4,5,7), dtype=dtype1, unpack=True)

# 4567 - mAge converts the ages to the model's time axis. Presumably the ages
# are in Ma and 4567 Ma is the CAI age; confirm.
plt.plot(tV, TV, label="Avdellidou 23")
plt.errorbar(4567.-mAge, mTemp, fmt='o', xerr=sAge, yerr=msTemp)
# 1-D indices of the data points later than 30 Myr after CAIs. An (N, 1)
# index array, as returned by np.argwhere, would broadcast against the
# length-N noise vectors below into N x N matrices.
m = np.flatnonzero((4567-mAge) > 30)

#plt.scatter(4567-mAge[m], modelT, marker='^', zorder=10)

print('now MCing with N=',len(m), "data point")

# Monte-Carlo distribution of chi^2: perturb ages and temperatures by their
# Gaussian uncertainties and compare with the model curve at the perturbed ages.
XX=[]
for _ in range(10000):
    ageChron = 4567-mAge[m] + np.random.normal(0, 1, len(m))*sAge[m]
    temChron = mTemp[m] + np.random.normal(0, 1, len(m))*msTemp[m]
    # The third-to-last selected datum gets a uniform temperature draw
    # instead of a Gaussian one (scalar draw for a scalar slot).
    temChron[-3] = np.random.uniform(400, 550)
    modelT=np.interp(ageChron, tV, TV)
    '''
    plt.scatter(ageChron[-3], temChron[-3], marker='.', zorder=10)
    fig.canvas.draw()
    fig.canvas.flush_events()
    '''
    # NOTE(review): the residuals are temperatures but are weighted by the
    # *age* uncertainty sAge; msTemp looks like the intended weight. Confirm,
    # and re-derive chi2thr in the CDF analysis if the weighting changes.
    mychi2 = np.sum((modelT-temChron)**2/sAge[m]**2)
    XX.append(mychi2)
XX=np.array(XX)

'''
print("-----")
print(np.mean(XX),np.median(XX))
mychi2=np.median(XX)
print("Asuming that the EL parent body broke at ", breakupTime, "Myr after the CAI") 
print("We get chi2 =", mychi2, "nu =", len(m))
# with only thermochronometers > 30 Myr after CAI
print("and a distance from best case", (mychi2-28.6)/np.sqrt(2*(len(m)-1)),"σ")
# with all thermochronometers 
#print("distance from best case", (chi2-1.726235604495213)/np.sqrt(2*len(m)),"σ")
from scipy.stats import chi2
mychi2-=28.6
if (mychi2<=0):
    mychi2=0.0
print(1-chi2.cdf(mychi2, len(m)-1))
'''
plt.ioff()
plt.show()

plt.hist(XX, bins=100)
plt.show()

# Sorted Monte-Carlo chi^2 values; paired with i/N they form the empirical CDF.
XX=np.sort(XX)

import pickle
# Disabled: store the current distribution as the reference in best.bin.
'''
fff=open("best.bin", "wb")
pickle.dump(XX, fff)
fff.close()
'''
# Reference distribution from an earlier run (only used by the commented-out
# plot below). NOTE(review): pickle.load can execute arbitrary code, so only
# load a best.bin you produced yourself.
with open("best.bin", "rb") as fff:
    XXref = pickle.load(fff)
import csaps
ecdf = np.arange(1, len(XX)+1)/len(XX)
plt.plot(XX, ecdf)

# Smooth the empirical CDF with csaps and read it off at the chi^2 threshold.
sp = csaps.csaps(XX, ecdf)
chi2thr = 2339.
print(sp(chi2thr))
plt.scatter(chi2thr, sp(chi2thr))

#plt.plot(XXref,np.arange(1,len(XXref)+1)/len(XXref))
#plt.xscale('log')
#plt.hist(XX, bins=40)
plt.show()
